"""
---
title: "Parity Task"
summary: >
  This creates data for Parity Task from the paper Adaptive Computation Time
  for Recurrent Neural Networks
---

# Parity Task

This creates data for the Parity Task from the paper
[Adaptive Computation Time for Recurrent Neural Networks](https://arxiv.org/abs/1603.08983).

The input of the parity task is a vector of $0$'s, $1$'s and $-1$'s.
The output is the parity of the number of $1$'s: one if there is an odd number of $1$'s and zero otherwise.
The input is generated by setting a random number of elements (at least one) to either $1$ or $-1$,
each chosen with equal probability, leaving the rest as $0$, and then shuffling the positions.
"""

from typing import Tuple

import torch
from torch.utils.data import Dataset


class ParityDataset(Dataset):
    """
    ### Parity dataset

    A map-style dataset of randomly generated parity-task samples.
    A new random sample is drawn on every access, so `idx` is only checked
    against the dataset length; it does not select a fixed sample.
    """

    def __init__(self, n_samples: int, n_elems: int = 64):
        """
        * `n_samples` is the number of samples (the length of the dataset); must be non-negative
        * `n_elems` is the number of elements in the input vector; must be at least $1$
        """
        # Validate eagerly: with `n_elems < 1` sampling would otherwise only fail
        # later, inside `torch.randint`, with an opaque error message.
        if n_samples < 0:
            raise ValueError(f"n_samples must be non-negative, got {n_samples}")
        if n_elems < 1:
            raise ValueError(f"n_elems must be at least 1, got {n_elems}")
        self.n_samples = n_samples
        self.n_elems = n_elems

    def __len__(self):
        """
        Size of the dataset
        """
        return self.n_samples

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate a random sample.

        Returns `(x, y)`:
        * `x` is a float vector of length `n_elems` containing $0$'s, $1$'s and $-1$'s,
          with at least one non-zero element
        * `y` is a 0-d integer tensor: $1$ if `x` has an odd number of $1$'s, else $0$

        Raises `IndexError` if `idx` is outside `[-n_samples, n_samples)`. Without this,
        Python's fallback sequence-iteration protocol (`for x, y in dataset`,
        `list(dataset)`) would never terminate, since `Dataset` defines no `__iter__`.
        """
        if not -self.n_samples <= idx < self.n_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.n_samples}"
            )

        # Empty vector
        x = torch.zeros((self.n_elems,))
        # Number of non-zero elements - a random number between $1$ and total number of elements
        n_non_zero = torch.randint(1, self.n_elems + 1, (1,)).item()
        # Fill non-zero elements with $1$'s and $-1$'s: `{0, 1} * 2 - 1` maps to `{-1, 1}`
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        # Randomly permute the elements so non-zeros are not always at the front
        x = x[torch.randperm(self.n_elems)]

        # The parity of the number of $1$'s
        y = (x == 1.).sum() % 2

        #
        return x, y
